{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VnA4LuxbTyKb",
        "outputId": "d71b5b4d-f9b8-4205-b08a-6864cc9931a2"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "nvcc: NVIDIA (R) Cuda compiler driver\n",
            "Copyright (c) 2005-2020 NVIDIA Corporation\n",
            "Built on Mon_Oct_12_20:09:46_PDT_2020\n",
            "Cuda compilation tools, release 11.1, V11.1.105\n",
            "Build cuda_11.1.TC455_06.29190527_0\n"
          ]
        }
      ],
      "source": [
        "# Print the version of the CUDA compiler (nvcc) installed in this runtime.\n",
        "!nvcc --version"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DtPcc6kJUE50"
      },
      "outputs": [],
      "source": [
        "# Optional: uncomment the line below to install or upgrade the CUDA-enabled JAX wheels.\n",
        "#!pip install --upgrade \"jax[cuda]\" -f https://storage.googleapis.com/jax-releases/jax_releases.html"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "sI5-pu0SUN_n",
        "outputId": "df12f76f-7123-45ad-bbc7-5dc2b2199a5e"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Sun Jun 26 12:51:39 2022       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   39C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "|  No running processes found                                                 |\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        }
      ],
      "source": [
        "# Run nvidia-smi; IPython captures the command's output as a list of lines.\n",
        "smi_lines = !nvidia-smi\n",
        "gpu_info = '\\n'.join(smi_lines)\n",
        "\n",
        "# The word 'failed' in the output is treated as \"no GPU attached to this runtime\".\n",
        "if 'failed' not in gpu_info:\n",
        "  print(gpu_info)\n",
        "else:\n",
        "  print('Select the Runtime > \"Change runtime type\" menu to enable a GPU accelerator, ')\n",
        "  print('and then re-execute this cell.')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uRNTcpf8Ui_W",
        "outputId": "7cc7a135-8c45-4448-89a0-bd6fa800e467"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "gpu\n"
          ]
        }
      ],
      "source": [
        "# Report the platform JAX selected as its default backend (e.g. 'cpu' or 'gpu').\n",
        "# jax.default_backend() is the public API for this; jax.lib.xla_bridge is an\n",
        "# internal module that newer JAX releases deprecate.\n",
        "import jax\n",
        "\n",
        "print(jax.default_backend())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ix8xq__wV7cX"
      },
      "outputs": [],
      "source": [
        "import jax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JOg7ciNvUnYP",
        "outputId": "55240124-c1b4-47dc-dd11-8ba3e32cbf79"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[GpuDevice(id=0, process_index=0)]"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ],
      "source": [
        "# List the devices of the default backend.\n",
        "jax.devices()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "u0Jwogb9Z6LD",
        "outputId": "3d916f3f-d25d-4071-b312-98d939791831"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[CpuDevice(id=0)]"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ],
      "source": [
        "# Explicitly list the devices of the CPU backend.\n",
        "jax.devices(\"cpu\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8QVl4uEhZ9EC",
        "outputId": "b11d94b0-18f8-4681-93eb-624dd735db5f"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[GpuDevice(id=0, process_index=0)]"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ],
      "source": [
        "# Explicitly list the devices of the GPU backend.\n",
        "jax.devices(\"gpu\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mpaNmoJaUquI"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import jax.numpy as jnp"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HGIhumiobfA1"
      },
      "outputs": [],
      "source": [
        "def f(x):\n",
        "  \"\"\"Benchmark workload: a handful of element-wise array operations.\n",
        "\n",
        "  Only operators and `.T` are used, so the same function is applied to both\n",
        "  the NumPy array and the JAX arrays in this notebook. `x * x.T` multiplies x\n",
        "  element-wise by its transpose (it is not a matrix product), so x must be\n",
        "  square or otherwise broadcast-compatible with `x.T`.\n",
        "  \"\"\"\n",
        "  y1 = x + x*x + 3\n",
        "  y2 = x*x + x*x.T\n",
        "  return y1*y2\n",
        "\n",
        "# generate some random float32 data on the host\n",
        "x = np.random.randn(3000, 3000).astype('float32')\n",
        "\n",
        "# Pass the NumPy array straight to device_put so each copy lands directly on\n",
        "# its target device; wrapping it in jnp.array() first would materialise it on\n",
        "# the default device and then copy it a second time.\n",
        "jax_x_gpu = jax.device_put(x, jax.devices('gpu')[0])\n",
        "jax_x_cpu = jax.device_put(x, jax.devices('cpu')[0])\n",
        "\n",
        "# compile function to CPU and GPU backends with JAX\n",
        "jax_f_cpu = jax.jit(f, backend='cpu')\n",
        "jax_f_gpu = jax.jit(f, backend='gpu')\n",
        "\n",
        "# warm-up: trigger compilation and wait for the asynchronously dispatched\n",
        "# results, so no leftover work leaks into the timing cells that follow\n",
        "jax_f_cpu(jax_x_cpu).block_until_ready()\n",
        "jax_f_gpu(jax_x_gpu).block_until_ready();"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EeAFVSNNeL8b",
        "outputId": "a46419af-9f1a-4864-fcc0-e720ab2225bd"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "100 loops, best of 5: 49.8 ms per loop\n"
          ]
        }
      ],
      "source": [
        "# Baseline: f applied to the NumPy array x (no JAX involved).\n",
        "%timeit -n100 f(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GUULya4NfWiy",
        "outputId": "32d51f98-119d-4436-f7a2-faed0bcf6c81"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "100 loops, best of 5: 59.5 ms per loop\n"
          ]
        }
      ],
      "source": [
        "# Un-jitted f on the JAX array placed on the CPU device; block_until_ready()\n",
        "# makes the timing include the asynchronously dispatched computation.\n",
        "%timeit -n100 f(jax_x_cpu).block_until_ready()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vzlZHhyxfDQC",
        "outputId": "b1aa4973-85e9-4a22-9737-89fd692e4fc3"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "100 loops, best of 5: 10.5 ms per loop\n"
          ]
        }
      ],
      "source": [
        "# f compiled with jax.jit for the CPU backend, applied to the CPU-resident array.\n",
        "%timeit -n100 jax_f_cpu(jax_x_cpu).block_until_ready()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Un-jitted f on the JAX array placed on the GPU device; block_until_ready()\n",
        "# makes the timing include the asynchronously dispatched computation.\n",
        "%timeit -n100 f(jax_x_gpu).block_until_ready()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Q3MqYbe6F1fB",
        "outputId": "d3f226a3-dab0-44f3-c2a8-9f3efb7abfc5"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "100 loops, best of 5: 1.87 ms per loop\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0lZseCFDfQG_",
        "outputId": "35414270-957a-4e05-8007-fd9c2da9908b"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "100 loops, best of 5: 649 µs per loop\n"
          ]
        }
      ],
      "source": [
        "# f compiled with jax.jit for the GPU backend, applied to the GPU-resident array.\n",
        "%timeit -n100 jax_f_gpu(jax_x_gpu).block_until_ready()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Check which device holds the array that was placed on the CPU.\n",
        "# NOTE(review): newer JAX releases expose this as `.devices()` / a `.device`\n",
        "# property instead of a method - confirm against the installed version.\n",
        "jax_x_cpu.device()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uQwtrX21EeyZ",
        "outputId": "c06004c4-042f-477e-e9f5-176318fddfcf"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "CpuDevice(id=0)"
            ]
          },
          "metadata": {},
          "execution_count": 27
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Check which device holds the array that was placed on the GPU.\n",
        "# NOTE(review): newer JAX releases expose this as `.devices()` / a `.device`\n",
        "# property instead of a method - confirm against the installed version.\n",
        "jax_x_gpu.device()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7ni0L_Z8GT0A",
        "outputId": "7e2e4f0a-7813-4b4d-9ca9-b8c73ed1fead"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "GpuDevice(id=0, process_index=0)"
            ]
          },
          "metadata": {},
          "execution_count": 28
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "jJTdqcZAEiwo"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "JAX_in_Action_Chapter_1_JAX_speedup.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}